// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

/*
 * Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar,
 * Pradeep Ramani, Tri Dao. Licensed under the BSD 3-Clause.
 *
 * Modified by the FlashInfer team.
 */

#ifndef ATTENTION_HOPPER_PREFILL_SM90_CUH_
#define ATTENTION_HOPPER_PREFILL_SM90_CUH_

#include <cuda.h>
#include <cuda_device_runtime_api.h>
#include <cutlass/arch/reg_reconfig.h>
#include <cutlass/array.h>
#include <cutlass/cutlass.h>
#include <cutlass/numeric_conversion.h>
#include <cutlass/numeric_types.h>

#include <type_traits>
#include <vector>

#include "attention_updater.cuh"
#include "cute/tensor.hpp"
#include "cutlass/pipeline/pipeline.hpp"
#include "epilogue.cuh"
#include "helper.h"
#include "kernel_traits.cuh"
#include "mainloop_load.cuh"
#include "mainloop_mma.cuh"
#include "utils.cuh"

#ifdef DEBUG_MLA
#undef DEBUG_MLA
#endif
// #define DEBUG_MLA

namespace mla_attn {

using namespace cute;

template <typename DTypeQ_,
          typename DTypeKV_,
          typename DTypeO_,
          typename IdType_>
struct Params {
  using DTypeQ = DTypeQ_;
  using DTypeKV = DTypeKV_;
  using DTypeO = DTypeO_;
  using IdType = IdType_;

  alignas(16) DTypeQ *Q;      // [token_num, head_num, dim_head]
  alignas(16) DTypeKV *KV;    // [max_block_num, block_size, dim_head]
  alignas(16) DTypeO *O;      // [token_num, head_num, dim_head]
  alignas(16) DTypeO *O_tmp;  // [max_num_chunks, bsz, head_num, dim_head]
  alignas(
      16) float *m;  // [max_num_chunks, bsz * max_draft_token_num * head_num]
  alignas(
      16) float *d;  // [max_num_chunks, bsz * max_draft_token_num * head_num]

  alignas(16) IdType *block_tables;
  alignas(16) IdType *seq_lens_this_time;
  alignas(16) IdType *seq_lens_decoder;
  alignas(16) IdType *cumsum_q_seqlens;
  alignas(16) IdType *batch_id_per_token;

  alignas(16) IdType *batch_ids;
  alignas(16) IdType *tile_ids_per_batch;
  alignas(16) IdType *num_blocks_x;
  alignas(16) IdType *chunk_size_device;

  uint32_t q_stride_bsz;
  uint32_t q_stride_head_num;

  uint32_t kv_stride_block_num;
  uint32_t kv_stride_block_size;

  uint32_t o_stride_bsz;
  uint32_t o_stride_head_num;

  int bsz;
  int token_num;
  int max_block_num;
  int max_block_num_per_seq;
  int q_num_head;
  int qk_head_dim;
  int vo_head_dim;
  int block_size;
  int max_draft_token_num;
  int chunk_num;

  float sm_scale;
};

#define DISPATCH_GROUP_SIZE(group_size, GROUP_SIZE, ...)  \
  if (group_size == 8) {                                  \
    constexpr size_t GROUP_SIZE = 8;                      \
    __VA_ARGS__                                           \
  } else if (group_size == 16) {                          \
    constexpr size_t GROUP_SIZE = 16;                     \
    __VA_ARGS__                                           \
  } else if (group_size == 64) {                          \
    constexpr size_t GROUP_SIZE = 64;                     \
    __VA_ARGS__                                           \
  } else if (group_size == 128) {                         \
    constexpr size_t GROUP_SIZE = 128;                    \
    __VA_ARGS__                                           \
  } else {                                                \
    PD_THROW("not support the group_size: ", group_size); \
    return cudaErrorNotSupported;                         \
  }

template <typename CollectiveMainloop,
          typename CollectiveEpilogue,
          typename Ktraits,
          bool CAUSAL,
          int SM_COUNT = 132,
          bool USE_REG_EALLOC = false,
          bool USE_FIXED_BLOCK = true>
__global__ void __launch_bounds__(
    Ktraits::NUM_WARPS *cutlass::NumThreadsPerWarp, 1)
    MLAWithKVCacheKernel(
        CUTE_GRID_CONSTANT
        typename CollectiveMainloop::Params const mainloop_params,
        CUTE_GRID_CONSTANT
        typename CollectiveEpilogue::Params const epilogue_params) {
  using DTypeQ = typename Ktraits::DTypeQ;
  using DTypeKV = typename Ktraits::DTypeKV;
  using DTypeO = typename Ktraits::DTypeO;
  using DTypeQKAccum = typename Ktraits::DTypeQKAccum;
  using TileShape_QKD = typename Ktraits::TileShape_QKD;
  using TileShape_PDV = typename Ktraits::TileShape_PDV;

  static constexpr int NUM_MMA_THREADS = Ktraits::NUM_MMA_THREADS;
  static constexpr int NUM_COPY_THREADS = Ktraits::NUM_PRODUCER_THREADS;
  static constexpr int BLOCK_SHAPE_Q = Ktraits::BLOCK_SHAPE_Q;
  static constexpr int BLOCK_SHAPE_KV = Ktraits::BLOCK_SHAPE_KV;
  const int num_blocks_x = mainloop_params.num_blocks_x[0];
  const int chunk_size = mainloop_params.chunk_size_device[0];

  static constexpr bool use_tma_load_kv = CollectiveMainloop::USE_TMA_LOAD_KV;

  using MainloopPipeline = typename CollectiveMainloop::MainloopPipeline;
  using PipelineParams = typename MainloopPipeline::Params;
  using PipelineState = typename MainloopPipeline::PipelineState;

  using MainloopPipelineQ = typename CollectiveMainloop::MainloopPipelineQ;
  using PipelineParamsQ = typename MainloopPipelineQ::Params;
  using PipelineStateQ = typename MainloopPipelineQ::PipelineState;

  extern __shared__ char shared_memory[];
  auto &shared_storage =
      *reinterpret_cast<typename Ktraits::SharedStorage *>(shared_memory);

  int const lane_predicate = cute::elect_one_sync();
  int const warp_idx = cutlass::canonical_warp_idx_sync();

  if (warp_idx == 0 && lane_predicate) {
    CollectiveMainloop::prefetch_tma_descriptors(mainloop_params);
    CollectiveEpilogue::prefetch_tma_descriptors(epilogue_params);
  }

  // Obtain warp index
  int const warp_group_thread_idx =
      threadIdx.x % cutlass::NumThreadsPerWarpGroup;

  PipelineParams pipeline_params;
  int warp_group_idx = cutlass::canonical_warp_group_idx();
  pipeline_params.role = warp_group_idx == 0
                             ? MainloopPipeline::ThreadCategory::Producer
                             : MainloopPipeline::ThreadCategory::Consumer;
  if constexpr (use_tma_load_kv) {
    pipeline_params.is_leader = warp_group_thread_idx == 0;
    pipeline_params.num_consumers = NUM_MMA_THREADS;
  } else {
    pipeline_params.producer_arv_count = NUM_COPY_THREADS;
    pipeline_params.consumer_arv_count = NUM_MMA_THREADS;
  }

  PipelineParamsQ pipeline_params_q;
  pipeline_params_q.role = warp_group_idx == 0
                               ? MainloopPipelineQ::ThreadCategory::Producer
                               : MainloopPipelineQ::ThreadCategory::Consumer;
  pipeline_params_q.producer_arv_count = NUM_COPY_THREADS;
  pipeline_params_q.consumer_arv_count =
      cutlass::NumThreadsPerWarpGroup;  // just one wg qk

  MainloopPipelineQ pipeline_q(shared_storage.pipeline_q, pipeline_params_q);
  MainloopPipeline pipeline_kv = [&] {
    if constexpr (use_tma_load_kv) {
      pipeline_params.transaction_bytes =
          CollectiveMainloop::TmaTransactionBytesKV;
      return MainloopPipeline(shared_storage.pipeline_kv,
                              pipeline_params,
                              /*cluster_shape=*/Shape<_1, _1, _1>{});
    } else {
      return MainloopPipeline(shared_storage.pipeline_kv, pipeline_params);
    }
  }();
  __syncthreads();

  CollectiveMainloop collective_mainloop;
  CollectiveEpilogue collective_epilogue;

  if (warp_group_idx == 0) {
    // producer
    if constexpr (USE_REG_EALLOC) {
      cutlass::arch::warpgroup_reg_dealloc<72>();
    }
    const uint32_t warp_idx_in_warpgroup =
        __shfl_sync(0xffffffff, warp_idx % 4, 0);

    PipelineStateQ smem_pipe_write_q =
        cutlass::make_producer_start_state<MainloopPipelineQ>();
    PipelineState smem_pipe_write_kv =
        cutlass::make_producer_start_state<MainloopPipeline>();
    for (int i = blockIdx.x; i < num_blocks_x; i += SM_COUNT) {
      const int bid = mainloop_params.batch_ids[i];
      const int tile_id = mainloop_params.tile_ids_per_batch[i];
      const int seq_len_now = mainloop_params.seq_lens_this_time[bid];
      const int seq_len_decoder_now =
          mainloop_params.seq_lens_decoder[bid] + seq_len_now;
      const int start_token_idx = mainloop_params.cumsum_q_seqlens[bid];
      cutlass::arch::NamedBarrier::sync(
          Ktraits::NUM_THREADS,
          /*id=*/static_cast<int>(NamedBarriers::kWG0WG1WG2Sync));

      // load Q
      collective_mainloop.load_q(mainloop_params,
                                 pipeline_q,
                                 smem_pipe_write_q,
                                 shared_storage,
                                 threadIdx.x,
                                 bid);

      if constexpr (!use_tma_load_kv) {
        // load kv
        collective_mainloop.load_kv(mainloop_params,
                                    pipeline_kv,
                                    smem_pipe_write_kv,
                                    shared_storage,
                                    bid,
                                    seq_len_decoder_now,
                                    tile_id);
      } else {
        if (warp_idx_in_warpgroup == 0) {
          // load kv tma
          collective_mainloop.load_kv_tma(mainloop_params,
                                          pipeline_kv,
                                          smem_pipe_write_kv,
                                          shared_storage,
                                          bid,
                                          seq_len_decoder_now,
                                          tile_id);
        }
      }
    }
  } else {
    // consumer
    if constexpr (USE_REG_EALLOC) {
      cutlass::arch::warpgroup_reg_alloc<216>();
    }
    PipelineStateQ smem_pipe_read_q;
    PipelineState smem_pipe_read_kv;

    typename Ktraits::TiledMmaPVSS tiled_mma_pv;
    Tensor tOrO =
        partition_fragment_C(tiled_mma_pv, select<0, 1>(TileShape_PDV{}));

    auto attention_updater =
        OnlineSoftmax<2 * size<1>(tOrO), /*WITH_SCALE=*/true>(
            mainloop_params.sm_scale);
    for (int i = blockIdx.x; i < num_blocks_x; i += SM_COUNT) {
      clear(tOrO);
      clear(attention_updater.scores_scale);
      const int bid = mainloop_params.batch_ids[i];
      const int tile_id = mainloop_params.tile_ids_per_batch[i];
      const int seq_len_now = mainloop_params.seq_lens_this_time[bid];
      const int seq_len_decoder_now =
          mainloop_params.seq_lens_decoder[bid] + seq_len_now;
      const int start_token_idx = mainloop_params.cumsum_q_seqlens[bid];
      cutlass::arch::NamedBarrier::sync(
          Ktraits::NUM_THREADS,
          /*id=*/static_cast<int>(NamedBarriers::kWG0WG1WG2Sync));

      if constexpr (BLOCK_SHAPE_KV == 64) {
        mma_f16<Ktraits, CAUSAL>(mainloop_params,
                                 pipeline_q,
                                 smem_pipe_read_q,
                                 pipeline_kv,
                                 smem_pipe_read_kv,
                                 tOrO,
                                 attention_updater,
                                 threadIdx.x - NUM_COPY_THREADS,
                                 bid,
                                 seq_len_decoder_now,
                                 seq_len_now,
                                 tile_id,
                                 shared_storage);
      } else if (BLOCK_SHAPE_KV == 32) {
        mma_f16_two_stages<Ktraits, CAUSAL>(mainloop_params,
                                            pipeline_q,
                                            smem_pipe_read_q,
                                            pipeline_kv,
                                            smem_pipe_read_kv,
                                            tOrO,
                                            attention_updater,
                                            threadIdx.x - NUM_COPY_THREADS,
                                            bid,
                                            seq_len_decoder_now,
                                            seq_len_now,
                                            tile_id,
                                            shared_storage);
      }

      collective_epilogue.store(epilogue_params,
                                tOrO,
                                attention_updater.get_lse(),
                                shared_storage,
                                tiled_mma_pv,
                                threadIdx.x - NUM_COPY_THREADS,
                                bid,
                                mainloop_params.bsz,
                                seq_len_now,
                                start_token_idx,
                                tile_id,
                                seq_len_decoder_now,
                                chunk_size,
                                mainloop_params.max_draft_token_num,
                                mainloop_params.o_stride_bsz);
    }
  }
}

template <typename KernelTraits,
          bool CAUSAL,
          typename Params,
          bool USE_REG_EALLOC = false,
          bool USE_FIXED_BLOCK = true>
cudaError_t BatchMLAWithPagedKVCacheKernelTraitsDispatched(
    Params &params, cudaStream_t stream) {
  using DTypeQ = typename KernelTraits::DTypeQ;
  using DTypeKV = typename KernelTraits::DTypeKV;
  using DTypeO = typename KernelTraits::DTypeO;
  using IdType = typename KernelTraits::IdType;
  using NV_TYPE = typename KernelTraits::NV_TYPE;

  using CollectiveMainloop = CollectiveMainloop<KernelTraits, CAUSAL>;
  using CollectiveEpilogue = CollectiveEpilogue<KernelTraits>;

  typename CollectiveMainloop::Params mainloop_params =
      CollectiveMainloop::to_underlying_arguments(
          {make_layout(
               make_shape(KernelTraits::BLOCK_SHAPE_Q, params.qk_head_dim),
               make_stride(params.qk_head_dim, _1{})),  // layout q
           make_layout(
               make_shape(
                   params.block_size, params.qk_head_dim, params.max_block_num),
               make_stride(params.qk_head_dim,
                           _1{},
                           params.block_size * params.qk_head_dim)),
           make_layout(make_shape(params.chunk_num,
                                  params.bsz * params.max_draft_token_num *
                                      params.q_num_head),
                       make_stride(params.bsz * params.max_draft_token_num *
                                       params.q_num_head,
                                   _1{})),
           params.Q,
           params.KV,
           params.m,
           params.d,
           params.block_tables,
           params.seq_lens_this_time,
           params.seq_lens_decoder,
           params.cumsum_q_seqlens,
           params.batch_ids,
           params.tile_ids_per_batch,
           params.num_blocks_x,
           params.chunk_size_device,
           params.sm_scale,
           params.bsz,
           params.max_block_num,
           params.max_block_num_per_seq,
           params.q_stride_bsz,
           params.q_stride_head_num,
           params.kv_stride_block_num,
           params.kv_stride_block_size,
           params.o_stride_bsz,
           params.o_stride_head_num,
           params.chunk_num,
           params.max_draft_token_num});
  typename CollectiveEpilogue::Params epilogue_params =
      CollectiveEpilogue::to_underlying_arguments_ntma({
          params.O,
          make_layout(
              make_shape(KernelTraits::BLOCK_SHAPE_Q, params.vo_head_dim),
              make_stride(params.vo_head_dim, _1{})),  // layout O
          params.O_tmp,
          make_layout(
              make_shape(KernelTraits::BLOCK_SHAPE_Q, params.vo_head_dim),
              make_stride(params.vo_head_dim, _1{}))  // layout O_tmp
      });

  // Get the ptr to kernel function.
  auto kernel = MLAWithKVCacheKernel<CollectiveMainloop,
                                     CollectiveEpilogue,
                                     KernelTraits,
                                     CAUSAL,
                                     132>;
  int smem_size = sizeof(typename KernelTraits::SharedStorage);
  cudaFuncSetAttribute(
      kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size);
  int device;
  cudaGetDevice(&device);
  int multiprocessor_count;
  cudaDeviceGetAttribute(
      &multiprocessor_count, cudaDevAttrMultiProcessorCount, device);
  int act_blocks_per_sm;
  cudaOccupancyMaxActiveBlocksPerMultiprocessor(
      &act_blocks_per_sm, kernel, KernelTraits::NUM_WARPS * 32, smem_size);

  // NOTE: (changwenbin) Here the grid size is fixed so that MLA can be captured
  // by the graph.
  dim3 grid_dims = {multiprocessor_count, 1, 1};
  static constexpr int ctaSize = KernelTraits::NUM_WARPS * 32;
  dim3 block_dims(ctaSize, 1, 1);
  kernel<<<grid_dims, block_dims, smem_size, stream>>>(mainloop_params,
                                                       epilogue_params);
  if (params.chunk_num > 1) {
    constexpr int vec_size = 16 / sizeof(DTypeO);
    constexpr int merge_block_size = 256;
    constexpr int blockx = KernelTraits::HEAD_DIM_VO / vec_size;
    constexpr int blocky = (merge_block_size + blockx - 1) / blockx;
    dim3 grids_merge(multiprocessor_count,
                     params.q_num_head);  // 128k is too large
    dim3 blocks_merge(blockx, blocky);
    merge_multi_chunks_kernel<NV_TYPE,
                              vec_size,
                              blocky,
                              KernelTraits::HEAD_DIM_VO>
        <<<grids_merge, blocks_merge, 0, stream>>>(
            reinterpret_cast<NV_TYPE *>(params.O_tmp),
            params.m,
            params.d,
            params.seq_lens_this_time,
            params.seq_lens_decoder,
            params.cumsum_q_seqlens,
            params.batch_id_per_token,
            params.chunk_size_device,
            reinterpret_cast<NV_TYPE *>(params.O),
            params.q_num_head,
            params.vo_head_dim,
            params.token_num,
            params.bsz,
            params.max_draft_token_num);
  }
  return cudaSuccess;
}

template <uint32_t HEAD_DIM_QK,
          uint32_t HEAD_DIM_VO,
          typename NV_TYPE,
          typename Params,
          bool USE_REG_EALLOC = false,
          bool USE_FIXED_BLOCK = true>
cudaError_t BatchMLAWithPagedKVCacheDispatched(Params &params,
                                               cudaStream_t stream) {
  constexpr bool CAUSAL = true;
  if constexpr (HEAD_DIM_QK == 576) {
    DISPATCH_GROUP_SIZE(params.q_num_head,
                        GROUP_SIZE,
                        BatchMLAWithPagedKVCacheKernelTraitsDispatched<
                            AttentionKernelTraits</*USE_TMA_LOAD_KV=*/true,
                                                  HEAD_DIM_QK,
                                                  HEAD_DIM_VO,
                                                  GROUP_SIZE,
                                                  /*BLOCK_SHAPE_Q_=*/64,
                                                  /*BLOCK_SHAPE_KV_=*/64,
                                                  /*NUM_STAGES_=*/2,
                                                  typename Params::DTypeQ,
                                                  typename Params::DTypeKV,
                                                  typename Params::DTypeO,
                                                  typename Params::IdType,
                                                  NV_TYPE>,
                            CAUSAL,
                            Params,
                            USE_REG_EALLOC,
                            USE_FIXED_BLOCK>(params, stream);)
  } else {
    return cudaErrorNotSupported;
  }
  return cudaSuccess;
};

}  // namespace mla_attn

#endif  // ATTENTION_HOPPER_PREFILL_SM90_CUH_
